Skip to content

标记数据 (Flagging)

数据标记是 Gradio 中一个重要的功能,允许用户在使用您的机器学习模型时,对特定的输入-输出对进行标记。这种功能在收集反馈、识别模型错误或构建标注数据集时非常有用。

标记按钮介绍

默认情况下,每个 Gradio Interface 的输出组件下方都会显示一个"标记"(Flag)按钮。当用户看到有趣的、意外的或错误的结果时,他们可以点击这个按钮,将当前的输入和输出数据发送回运行演示的服务器。这些标记的数据可以被保存用于后续的模型改进。

配置标记行为

gr.Interface 构造函数中,有四个关键参数用于控制标记功能的行为:

1. flagging_mode

该参数控制标记按钮的显示方式和标记行为:

  • "manual" (默认): 用户将看到一个标记按钮,只有在点击按钮时才会标记样本。
  • "auto": 用户不会看到标记按钮,但每个提交的样本都会自动标记。
  • "never": 禁用标记功能,用户不会看到标记按钮,也不会标记任何样本。
python
import gradio as gr

def calculator(num1, operation, num2):
    if operation == "add":
        return num1 + num2
    elif operation == "subtract":
        return num1 - num2
    elif operation == "multiply":
        return num1 * num2
    elif operation == "divide":
        if num2 == 0:
            raise gr.Error("不能除以零!")
        return num1 / num2

# 自动标记所有提交的计算
demo = gr.Interface(
    calculator,
    ["number", gr.Radio(["add", "subtract", "multiply", "divide"]), "number"],
    "number",
    flagging_mode="auto"
)

demo.launch()

2. flagging_options

该参数允许您自定义标记的原因选项:

  • 如果为 None(默认),用户只需点击"标记"按钮,不显示其他选项。
  • 如果提供字符串列表,用户将看到多个标记按钮,每个按钮对应提供的字符串。例如,["不正确", "有歧义"] 将显示"标记为不正确"和"标记为有歧义"按钮。
  • 用户选择的选项将与输入和输出一起记录在标记数据中。
python
# 提供标记原因选项
demo = gr.Interface(
    calculator,
    ["number", gr.Radio(["add", "subtract", "multiply", "divide"]), "number"],
    "number",
    flagging_mode="manual",
    flagging_options=["计算错误", "除以零", "其他问题"]
)

也可以使用元组列表提供自定义的标签和值:

python
# 使用自定义标签和值
demo = gr.Interface(
    calculator,
    ["number", gr.Radio(["add", "subtract", "multiply", "divide"]), "number"],
    "number",
    flagging_mode="manual",
    flagging_options=[
        ("结果不正确", "incorrect_result"),
        ("操作不支持", "unsupported_operation"),
        ("其他问题", "other_issue")
    ]
)

3. flagging_dir

该参数指定存储标记数据的目录:

python
demo = gr.Interface(
    calculator,
    ["number", gr.Radio(["add", "subtract", "multiply", "divide"]), "number"],
    "number",
    flagging_dir="./my_flagged_data"  # 自定义标记数据存储目录
)

如果不指定,默认为 ./.gradio/flagged/

4. flagging_callback

此参数允许您使用自定义的回调函数来处理标记的数据,而不是默认的 CSV 记录方式:

python
import gradio as gr
from gradio.flagging import FlaggingCallback

class MyCustomFlaggingCallback(FlaggingCallback):
    def setup(self, components, flagging_dir):
        # 初始化设置,例如连接到数据库
        self.log_file = open(f"{flagging_dir}/custom_logs.txt", "a")
        return self
    
    def flag(self, flag_data, flag_option=None):
        # 处理标记数据
        data_str = ", ".join([str(d) for d in flag_data])
        if flag_option:
            self.log_file.write(f"标记原因: {flag_option}, 数据: {data_str}\n")
        else:
            self.log_file.write(f"数据: {data_str}\n")
        self.log_file.flush()
        return

demo = gr.Interface(
    calculator,
    ["number", gr.Radio(["add", "subtract", "multiply", "divide"]), "number"],
    "number",
    flagging_callback=MyCustomFlaggingCallback()
)

标记数据的存储格式

当用户点击标记按钮时,数据将按照以下方式存储:

基本数据存储

对于基本的原始数据(数字、文本等),数据将存储在一个 CSV 文件中:

# <flagging_dir>/logs.csv
num1,operation,num2,Output,timestamp
5,add,7,12,2022-01-31 11:40:51.093412
6,subtract,1.5,4.5,2022-01-31 03:25:32.023542

文件数据存储

如果您的接口包含文件类型的输入或输出(如图像、音频等),这些文件将单独保存,CSV 文件中只存储文件路径:

# 目录结构
+-- flagged/
|   +-- logs.csv
|   +-- image/
|   |   +-- 0.png
|   |   +-- 1.png
|   +-- Output/
|   |   +-- 0.png
|   |   +-- 1.png

# <flagging_dir>/logs.csv
image,Output,timestamp
image/0.png,Output/0.png,2022-02-04 19:49:58.026963
image/1.png,Output/1.png,2022-02-02 10:40:51.093412

带有标记选项的数据

如果您使用了 flagging_options,被选择的选项也会记录在 CSV 文件中:

# <flagging_dir>/logs.csv
num1,operation,num2,Output,flag,timestamp
5,add,7,-12,计算错误,2022-02-04 11:40:51.093412
6,subtract,1.5,3.5,其他问题,2022-02-04 11:42:32.062512

在 Blocks 中使用标记功能

gr.Blocks() 中,您也可以实现标记功能,但这需要手动设置:

python
import gradio as gr
import numpy as np

def sepia(input_img):
    sepia_filter = np.array([
        [0.393, 0.769, 0.189],
        [0.349, 0.686, 0.168],
        [0.272, 0.534, 0.131]
    ])
    sepia_img = input_img.dot(sepia_filter.T)
    sepia_img /= sepia_img.max()
    return sepia_img

# 创建 CSV 记录器
callback = gr.CSVLogger()

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            img_input = gr.Image()
            transform_btn = gr.Button("应用滤镜")
        img_output = gr.Image()
    
    flag_btn = gr.Button("标记此结果")
    
    # 设置 logger
    callback.setup([img_input, img_output], "flagged_data")
    
    # 设置事件
    transform_btn.click(sepia, inputs=img_input, outputs=img_output)
    
    # 标记按钮点击事件
    flag_btn.click(
        lambda img_in, img_out: callback.flag([img_in, img_out]), 
        [img_input, img_output], 
        None,
        preprocess=False
    )

demo.launch()

标记数据的应用

通过标记功能收集的数据可以用于多种目的:

  1. 识别模型错误:标记功能可以帮助您收集模型表现不佳的数据点。
  2. 创建测试集:将收集到的难以处理的样本组织成测试集,用于模型评估。
  3. 改进模型:使用标记的数据进行模型的再训练或微调。
  4. 数据审计:检查模型在不同输入上的表现,识别潜在的偏见。

隐私考虑

使用标记功能时,请确保:

  1. 告知用户他们的数据何时会被保存。
  2. 明确您将如何使用这些标记的数据。
  3. 在使用 flagging_mode="auto" 时尤其要注意,因为所有用户提交的数据都会被自动保存。

结论

Gradio 的标记功能是一个强大的工具,可以帮助您收集反馈并改进模型。通过合理配置标记选项,您可以收集到有针对性的反馈,更有效地识别和解决模型中的问题。

在下一章中,我们将介绍如何在 Interface 中管理状态,这对于创建具有记忆功能的应用程序至关重要。